# Copyright (c) Alibaba, Inc. and its affiliates.
import math
from functools import partial, reduce

import torch
import torch.nn as nn
from mmcv.cnn import kaiming_init, normal_init
from packaging import version

from easycv.models.utils import GeMPooling, ResLayer
from ..backbones.hrnet import Bottleneck, get_expansion
from ..registry import NECKS
from ..utils import (ConvModule, _init_weights, build_conv_layer,
                     build_norm_layer)


@NECKS.register_module
class LinearNeck(nn.Module):
    """Linear neck: optional global average pooling followed by one fc.

    The output is a one-element list holding the (optionally L2-normalised)
    linear projection of the first input feature.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_avg_pool=True,
                 with_norm=False):
        """
        Args:
            in_channels (int): fc input size, i.e. the number of elements per
                sample after (optional) pooling and flattening.
            out_channels (int): fc output size.
            with_avg_pool (bool): apply ``AdaptiveAvgPool2d((1, 1))`` before
                the fc.
            with_norm (bool): L2-normalise the fc output along dim 1.
        """
        super().__init__()
        self.with_avg_pool = with_avg_pool
        if self.with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, out_channels)
        self.with_norm = with_norm

    def init_weights(self, init_linear='normal'):
        """Initialise parameters through the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Project the first backbone output.

        Args:
            x (list | tuple): backbone outputs; only ``x[0]`` is consumed.
                One or two entries are accepted (two to fit ViT models).

        Returns:
            list[Tensor]: single-element list with the projected feature.
        """
        assert len(x) in (1, 2)  # two entries are allowed to fit vit model
        feat = x[0]
        feat = self.avgpool(feat) if self.with_avg_pool else feat
        out = self.fc(feat.view(feat.size(0), -1))
        if not self.with_norm:
            return [out]
        return [nn.functional.normalize(out, p=2, dim=1)]


@NECKS.register_module
class RetrivalNeck(nn.Module):
    '''RetrivalNeck, following "Combination of Multiple Global Descriptors
    for Image Retrieval" (CGD): https://arxiv.org/pdf/1903.10663.pdf

    Two kinds of features are built from the first input feature map:

    - CGD feature: for every configured pooling ('S' average pooling,
      'G' GeMPooling, 'M' max pooling) the chain is
      pool -> fc -> L2 norm. The per-pooling results are concatenated
      along dim 1 in the fixed order S, G, M, independent of the order in
      ``cdg_config``. The concatenation itself is not re-normalised.
    - Avg feature (BNNeck): (optional) avg pool -> BN -> fc -> dropout.

    ``len(cdg_config) > 0`` returns ``[CGD, Avg]``;
    ``len(cdg_config) == 0`` returns ``[Avg]``.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_avg_pool=True,
                 cdg_config=None):
        """Init RetrivalNeck. Does not support dynamic input sizes.

        Args:
            in_channels (int): input feature map channels.
            out_channels (int): size of the Avg feature and total size of the
                CGD feature; must be divisible by ``len(cdg_config)`` when
                ``cdg_config`` is non-empty.
            with_avg_pool (bool): average-pool the input of the BNNeck
                branch. Without pooling, the flattened input must still have
                ``in_channels`` elements per sample for the fc to accept it.
            cdg_config (list[str] | None): unique entries from ``'G'``,
                ``'M'``, ``'S'`` selecting the CGD poolings. ``None``
                (default) means ``['G', 'M']``; an empty list disables the
                CGD branch.
        """
        super(RetrivalNeck, self).__init__()
        # Built per call instead of as a default argument so instances never
        # share one mutable list object.
        if cdg_config is None:
            cdg_config = ['G', 'M']
        cdg_config = list(cdg_config)
        # Unknown or repeated entries would silently give a CGD feature whose
        # width differs from out_channels (or an empty torch.cat in forward).
        unknown = set(cdg_config) - {'G', 'M', 'S'}
        assert not unknown, \
            'unsupported cdg_config entries: {}'.format(sorted(unknown))
        assert len(set(cdg_config)) == len(cdg_config), \
            'cdg_config entries must be unique'

        self.with_avg_pool = with_avg_pool
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(in_channels, out_channels, bias=False)
        self.dropout = nn.Dropout(p=0.3)
        # Plain BN; dict(type='SyncBN') is the cross-device alternative.
        _, self.bn_output = build_norm_layer(dict(type='BN'), in_channels)

        self.cdg_config = cdg_config
        cgd_length = len(cdg_config)
        if cgd_length > 0:
            assert (out_channels % cgd_length == 0)
            # Each descriptor gets an equal share of out_channels.
            branch_channels = out_channels // cgd_length
            if 'M' in cdg_config:
                self.mpool = nn.AdaptiveMaxPool2d((1, 1))
                self.fc_mx = nn.Linear(
                    in_channels, branch_channels, bias=False)
            if 'S' in cdg_config:
                self.spool = nn.AdaptiveAvgPool2d((1, 1))
                self.fc_sx = nn.Linear(
                    in_channels, branch_channels, bias=False)
            if 'G' in cdg_config:
                self.gpool = GeMPooling()
                self.fc_gx = nn.Linear(
                    in_channels, branch_channels, bias=False)

    def init_weights(self, init_linear='normal'):
        """Initialise parameters through the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    @staticmethod
    def _cgd_branch(x, pool, fc):
        """Pool, flatten, project and L2-normalise one CGD descriptor."""
        feat = fc(pool(x).view(x.size(0), -1))
        return nn.functional.normalize(feat, p=2, dim=1)

    def forward(self, x):
        """Compute the retrieval features.

        Args:
            x (list | tuple): backbone outputs; only ``x[0]`` is used.
                One or two entries are accepted (two to fit ViT models).

        Returns:
            list[Tensor]: ``[cgd_feat, cls_feat]`` when ``cdg_config`` is
            non-empty, otherwise ``[cls_feat]``.
        """
        assert len(x) == 1 or len(x) == 2  # to fit vit model
        x = x[0]

        # BNNeck branch, optionally on the avg-pooled map.
        ax = self.avgpool(x) if self.with_avg_pool else x
        cls_x = self.bn_output(ax)
        cls_x = self.fc(cls_x.view(x.size(0), -1))
        cls_x = self.dropout(cls_x)

        if len(self.cdg_config) == 0:
            return [cls_x]

        # CGD branch, concatenated in the fixed order S, G, M.
        concat_list = []
        if 'S' in self.cdg_config:
            concat_list.append(self._cgd_branch(x, self.spool, self.fc_sx))
        if 'G' in self.cdg_config:
            concat_list.append(self._cgd_branch(x, self.gpool, self.fc_gx))
        if 'M' in self.cdg_config:
            concat_list.append(self._cgd_branch(x, self.mpool, self.fc_mx))

        concatx = torch.cat(concat_list, dim=1)
        concatx = concatx.view(concatx.size(0), -1)
        return [concatx, cls_x]


@NECKS.register_module
class FaceIDNeck(nn.Module):
    '''FaceID neck: input norm, dropout, flatten, linear, output norm.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 map_shape=1,
                 dropout_ratio=0.4,
                 with_norm=False,
                 bn_type='SyncBN'):
        """ Init FaceIDNeck. The FaceID neck does not pool the input feature
        map, so it does not support dynamic input sizes.

        Args:
            in_channels: Int - input feature map channels
            out_channels: Int - output feature size
            map_shape: Int or list/tuple of ints - input feature map (w, h),
                or w when w == h; the fc input size is
                in_channels * prod(map_shape)
            dropout_ratio : float, dropout ratio
            with_norm : L2-normalize the output feature or not
            bn_type : 'SyncBN' or 'BN'; used for the input norm layer. The
                output norm is built from bn_type when it is 'SyncBN',
                otherwise it is nn.BatchNorm1d
        """
        super(FaceIDNeck, self).__init__()

        # Old torch SyncBN needs 4-D input, so 2-D features are expanded.
        if version.parse(torch.__version__) < version.parse('1.4.0'):
            self.expand_for_syncbn = True
        else:
            self.expand_for_syncbn = False

        _, self.bn_input = build_norm_layer(dict(type=bn_type), in_channels)
        self.dropout = nn.Dropout(p=dropout_ratio)

        if isinstance(map_shape, (list, tuple)):
            in_ = int(reduce(lambda a, b: a * b, map_shape) * in_channels)
        else:
            assert isinstance(map_shape, int)
            in_ = in_channels * map_shape * map_shape

        self.fc = nn.Linear(in_, out_channels)
        self.with_norm = with_norm
        self.syncbn = bn_type == 'SyncBN'
        if self.syncbn:
            _, self.bn_output = build_norm_layer(
                dict(type=bn_type), out_channels)
        else:
            self.bn_output = nn.BatchNorm1d(out_channels)

    def _forward_syncbn(self, module, x):
        """Apply a SyncBN-style ``module`` to a 2-D tensor.

        On torch < 1.4 the tensor is temporarily expanded to (N, C, 1, 1).
        """
        assert x.dim() == 2
        if self.expand_for_syncbn:
            x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
        else:
            x = module(x)
        return x

    def init_weights(self, init_linear='normal'):
        """Initialise parameters through the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Compute the face embedding.

        Args:
            x (list | tuple): backbone outputs; only ``x[0]`` is used.
                One or two entries are accepted (two to fit ViT models).

        Returns:
            list[Tensor]: single-element list with the (N, out_channels)
            embedding.
        """
        assert len(x) == 1 or len(x) == 2  # to fit vit model
        x = x[0]
        x = self.bn_input(x)
        x = self.dropout(x)
        x = self.fc(x.view(x.size(0), -1))
        # Only the SyncBN path may need the 4-D expansion; nn.BatchNorm1d
        # does not accept a 4-D (N, C, 1, 1) input, so call it directly.
        if self.syncbn:
            x = self._forward_syncbn(self.bn_output, x)
        else:
            x = self.bn_output(x)

        if self.with_norm:
            x = nn.functional.normalize(x, p=2, dim=1)
        return [x]


@NECKS.register_module
class MultiLinearNeck(nn.Module):
    '''MultiLinearNeck neck: one fc, or a stack of fc layers applied in order.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layers=1,
                 with_avg_pool=True):
        """
        Args:
            in_channels: int (num_layers == 1) or list[int] (one per layer)
            out_channels: int (num_layers == 1) or list[int] (one per layer);
                for stacked layers out_channels[i] must equal
                in_channels[i + 1]
            num_layers : total fc num
            with_avg_pool : input will be avgPool if True
        Returns:
            None
        Raises:
            AssertionError: if num_layers > 1 and len(in_channels) or
                len(out_channels) differs from num_layers
        """
        super(MultiLinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        self.num_layers = num_layers
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if num_layers == 1:
            self.fc = nn.Linear(in_channels, out_channels)
        else:
            assert len(in_channels) == len(out_channels)
            assert len(in_channels) == num_layers
            self.fc = nn.ModuleList(
                [nn.Linear(i, j) for i, j in zip(in_channels, out_channels)])

    def init_weights(self, init_linear='normal'):
        """Initialise parameters through the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Project the first backbone output through the fc layer(s).

        Args:
            x (list | tuple): backbone outputs; only ``x[0]`` is used.
                One or two entries are accepted (two to fit ViT models).

        Returns:
            list[Tensor]: single-element list with the projected feature.
        """
        assert len(x) == 1 or len(x) == 2  # to fit vit model
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        if isinstance(self.fc, nn.ModuleList):
            # nn.ModuleList is not callable; chain its layers explicitly.
            # NOTE(review): no activation is inserted between layers.
            for fc in self.fc:
                x = fc(x)
        else:
            x = self.fc(x)
        return [x]


@NECKS.register_module()
class HRFuseScales(nn.Module):
    """Fuse feature maps of multiple scales in HRNet.

    Each scale ``i`` is first lifted by a one-block Bottleneck ``ResLayer``
    to ``[128, 256, 512, 1024][i]`` channels. Starting from ``x[0]``, the
    running feature is passed through a stride-2 3x3 ``ConvModule`` and added
    to the next lifted scale. A final 1x1 ``ConvModule`` maps the result to
    ``out_channels``.

    Args:
        in_channels (list[int]): The input channels of all scales (1 to 4
            scales), in the order they are passed to ``forward``.
        out_channels (int): The channels of the fused feature map.
            Defaults to 2048.
        norm_cfg (dict): dictionary to construct norm layers.
            Defaults to ``dict(type='BN', momentum=0.1)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels=2048,
                 norm_cfg=dict(type='BN', momentum=0.1)):
        super(HRFuseScales, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg

        block_type = Bottleneck
        # Per-scale channel targets (renamed so the out_channels argument is
        # not shadowed).
        scale_channels = [128, 256, 512, 1024]
        num_scales = len(in_channels)
        assert 1 <= num_scales <= len(scale_channels), \
            'HRFuseScales supports 1 to {} scales, got {}'.format(
                len(scale_channels), num_scales)

        # Increase the channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        self.increase_layers = nn.ModuleList([
            ResLayer(
                block_type,
                in_channels=in_channels[i],
                out_channels=scale_channels[i],
                num_blocks=1,
                stride=1,
            ) for i in range(num_scales)
        ])

        # Downsample feature maps in each scale.
        self.downsample_layers = nn.ModuleList([
            ConvModule(
                in_channels=scale_channels[i],
                out_channels=scale_channels[i + 1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                bias=False,
            ) for i in range(num_scales - 1)
        ])

        # The final conv block before final classifier linear layer. Its
        # input is the channel count of the last used scale; hard-coding the
        # 4th entry broke configurations with fewer than 4 scales.
        self.final_layer = ConvModule(
            in_channels=scale_channels[num_scales - 1],
            out_channels=self.out_channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            bias=False,
        )

    def init_weights(self, init_linear='normal'):
        """Initialise parameters through the shared ``_init_weights`` helper."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Fuse the per-scale inputs ``x`` (one entry per ``in_channels``).

        Returns:
            list[Tensor]: single-element list with the fused feature map.
        """
        assert len(x) == len(self.in_channels)

        feat = self.increase_layers[0](x[0])
        for downsample, increase, scale_feat in zip(
                self.downsample_layers, self.increase_layers[1:], x[1:]):
            feat = downsample(feat) + increase(scale_feat)

        return [self.final_layer(feat)]
